import copy
import random
from tqdm import tqdm
import numpy as np
import os
import logging
from datetime import datetime
import json

# Names exported by `from <module> import *`: only EvolutionFinder; ArchManager is not listed.
__all__ = ["EvolutionFinder"]


class ArchManager:
    """Random sampler / mutator for a fixed architecture search space.

    A sample is a dict with these keys, in this insertion order:
        "wid": always None.
        "ks":  ``num_blocks`` kernel sizes drawn from ``kernel_sizes``.
        "e":   ``num_blocks`` expand ratios drawn from ``expand_ratios``.
        "d":   ``num_stages`` depths drawn from ``depths``.
        "r":   one-element list holding a resolution from ``resolutions``.
    """

    def __init__(self):
        self.num_blocks = 20
        self.num_stages = 5
        self.kernel_sizes = [3, 5, 7]
        self.expand_ratios = [3, 4, 6]
        self.depths = [2, 3, 4]
        self.resolutions = [160, 176, 192, 208, 224]

    @staticmethod
    def _check_index(i, limit, what):
        """Raise IndexError unless ``0 <= i < limit``.

        An explicit check (rather than ``assert``) still runs under ``python -O``;
        without it a negative ``i`` would silently overwrite an entry from the end.
        """
        if not 0 <= i < limit:
            raise IndexError(f"{what} index {i} out of range [0, {limit})")

    def random_sample(self):
        """Return a new uniformly random sample dict (see class docstring).

        Random draws happen in a fixed order: all stage depths, then
        (expand ratio, kernel size) per block, then the resolution. This
        keeps seeded runs reproducible.
        """
        depths = [random.choice(self.depths) for _ in range(self.num_stages)]

        expand_ratios = []
        kernel_sizes = []
        for _ in range(self.num_blocks):
            # Expand ratio is drawn before kernel size for each block.
            expand_ratios.append(random.choice(self.expand_ratios))
            kernel_sizes.append(random.choice(self.kernel_sizes))

        # Key order is significant: EvolutionFinder.crossover_sample walks the
        # keys in insertion order, which fixes its sequence of random draws.
        return {
            "wid": None,
            "ks": kernel_sizes,
            "e": expand_ratios,
            "d": depths,
            "r": [random.choice(self.resolutions)],
        }

    def random_resample(self, sample, i):
        """Redraw, in place, the kernel size and expand ratio of block ``i``.

        Raises:
            IndexError: if ``i`` is not in ``[0, num_blocks)``.
        """
        self._check_index(i, self.num_blocks, "block")
        sample["ks"][i] = random.choice(self.kernel_sizes)
        sample["e"][i] = random.choice(self.expand_ratios)

    def random_resample_depth(self, sample, i):
        """Redraw, in place, the depth of stage ``i``.

        Raises:
            IndexError: if ``i`` is not in ``[0, num_stages)``.
        """
        self._check_index(i, self.num_stages, "stage")
        sample["d"][i] = random.choice(self.depths)

    def random_resample_resolution(self, sample):
        """Redraw, in place, the single resolution entry of ``sample``."""
        sample["r"][0] = random.choice(self.resolutions)


class EvolutionFinder:
    """Constrained evolutionary architecture search driven by predictors.

    Candidates come from an ``ArchManager``. Each candidate is scored with
    ``efficiency_predictor.predict_efficiency(sample)``. Only candidates with
    ``efficiency <= efficiency_constraint`` are scored for accuracy via
    ``accuracy_predictor.predict_accuracy([sample])[0].item()``.

    Optional keyword arguments:
        mutate_prob: probability of independently resampling the resolution,
            each block's (ks, e) pair and each stage's depth (default 0.1).
        population_size: individuals kept per generation (default 100).
        max_time_budget: number of evolution iterations (default 500).
        parent_ratio: fraction of the population selected as parents
            (default 0.25).
        mutation_ratio: fraction of ``population_size`` offspring produced by
            mutation; the remainder come from crossover (default 0.5).
        seed: stored as ``self.seed`` (default 0). NOTE(review): this class
            never applies it to ``random`` / ``np.random``; callers must seed.
    """

    def __init__(
        self,
        efficiency_constraint,
        efficiency_predictor,
        accuracy_predictor,
        logger=None,
        **kwargs
    ):
        self.efficiency_constraint = efficiency_constraint
        self.efficiency_predictor = efficiency_predictor
        self.accuracy_predictor = accuracy_predictor
        self.arch_manager = ArchManager()
        self.num_blocks = self.arch_manager.num_blocks
        self.num_stages = self.arch_manager.num_stages
        # May be None; run_evolution_search falls back to a module logger.
        self.logger = logger

        self.mutate_prob = kwargs.get("mutate_prob", 0.1)
        self.population_size = kwargs.get("population_size", 100)
        self.max_time_budget = kwargs.get("max_time_budget", 500)
        self.parent_ratio = kwargs.get("parent_ratio", 0.25)
        self.mutation_ratio = kwargs.get("mutation_ratio", 0.5)
        self.seed = kwargs.get("seed", 0)

    def set_efficiency_constraint(self, new_constraint):
        """Replace the efficiency upper bound used by subsequent sampling."""
        self.efficiency_constraint = new_constraint

    def _sample_until_feasible(self, propose, query_counter=None):
        """Call ``propose()`` until a candidate satisfies the efficiency constraint.

        Args:
            propose: zero-argument callable returning a fresh candidate sample.
            query_counter: optional one-element list, incremented once per
                candidate whose efficiency is predicted.

        Returns:
            ``(sample, efficiency, accuracy)`` for the first feasible candidate.
            Note: loops forever if no candidate can satisfy the constraint.
        """
        constraint = self.efficiency_constraint
        while True:
            candidate = propose()
            if query_counter is not None:
                query_counter[0] += 1
            efficiency = self.efficiency_predictor.predict_efficiency(candidate)
            if efficiency <= constraint:
                acc = self.accuracy_predictor.predict_accuracy([candidate])[0].item()
                return candidate, efficiency, acc

    def random_sample(self, query_counter=None):
        """Return ``(sample, efficiency, accuracy)`` for a random feasible sample."""
        return self._sample_until_feasible(self.arch_manager.random_sample, query_counter)

    def mutate_sample(self, sample, query_counter=None):
        """Return a feasible mutated copy of ``sample`` with its scores.

        ``sample`` itself is never modified; each attempt mutates a deep copy.
        """
        def propose():
            child = copy.deepcopy(sample)
            if random.random() < self.mutate_prob:
                self.arch_manager.random_resample_resolution(child)
            for i in range(self.num_blocks):
                if random.random() < self.mutate_prob:
                    self.arch_manager.random_resample(child, i)
            for i in range(self.num_stages):
                if random.random() < self.mutate_prob:
                    self.arch_manager.random_resample_depth(child, i)
            return child

        return self._sample_until_feasible(propose, query_counter)

    def crossover_sample(self, sample1, sample2, query_counter=None):
        """Return a feasible uniform-crossover child of two samples with its scores.

        Every element of every list-valued key is picked at random from the
        two parents. Non-list values (e.g. ``"wid"``) are copied from
        ``sample1``.
        """
        def propose():
            child = copy.deepcopy(sample1)
            for key, genes in child.items():
                if not isinstance(genes, list):
                    continue
                for i in range(len(genes)):
                    genes[i] = random.choice([sample1[key][i], sample2[key][i]])
            return child

        return self._sample_until_feasible(propose, query_counter)

    def run_evolution_search(self):
        """Run one evolutionary search roll-out for ``max_time_budget`` iterations.

        Each iteration:
          - selects the top ``parents_size`` individuals by accuracy;
          - breeds ``mutation_numbers`` children by mutation and the remaining
            ``population_size - mutation_numbers`` by crossover;
          - keeps the best ``population_size`` of parents + children
            (elitist truncation; there is no age-based removal).

        Returns:
            ``[accuracy, sample, efficiency, query_index]`` of the best feasible
            architecture seen.

        Raises:
            ValueError: if ``population_size`` is smaller than 1.
        """
        logger = self.logger if self.logger is not None else logging.getLogger(__name__)

        max_time_budget = self.max_time_budget
        population_size = self.population_size
        if population_size < 1:
            raise ValueError(f"population_size must be >= 1, got {population_size}")
        mutation_numbers = int(round(self.mutation_ratio * population_size))
        crossover_numbers = max(0, population_size - mutation_numbers)
        # Clamp to [1, population_size]: np.random.randint(0) raises, and more
        # parents than individuals would index past the end of `parents`.
        parents_size = min(max(1, int(round(self.parent_ratio * population_size))), population_size)
        constraint = self.efficiency_constraint

        logger.info("Starting evolution search...")
        logger.info(f"Population size: {population_size}")
        logger.info(f"Parents size: {parents_size}")
        logger.info(f"Mutation numbers: {mutation_numbers}")
        logger.info(f"Constraint: {constraint}")

        best_info = None
        query_counter = [0]
        # Entries: (iteration, accuracy, efficiency, query); iteration None = initial population.
        global_best_arch_records = []

        population = []
        for _ in range(population_size):
            sample, efficiency, acc = self.random_sample(query_counter)
            population.append([acc, sample, efficiency])

            if best_info is None or acc > best_info[0]:
                best_info = [acc, sample, efficiency, query_counter[0]]
                logger.info(f"Initial population: New best architecture found at query {query_counter[0]} with accuracy {acc:.4f}, efficiency {efficiency:.2f}")

        global_best_arch_records.append((None, best_info[0], best_info[2], best_info[3]))

        # Ascending sort then reverse (not reverse=True): ties end up in reversed order.
        population = sorted(population, key=lambda x: x[0])[::-1]

        for iteration in tqdm(range(max_time_budget)):
            iter_start_query = query_counter[0]
            new_best_in_iter = False
            new_best_query = 0

            parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
            current_best_acc = parents[0][0]

            logger.info(f"\nIteration {iteration}:")
            logger.info(f"Parent best acc: {current_best_acc:.4f}, efficiency: {parents[0][2]:.2f}, queries: {query_counter[0]}")
            logger.info(f"Global best so far: ACC {best_info[0]:.4f}, FLOPs {best_info[2]:.2f} at query {best_info[3]}")
            for rank, (member_acc, member_arch, member_eff) in enumerate(population[:20], 1):
                logger.info(f"#{rank:02d} - ACC: {member_acc:.4f} EFFICIENCY: {member_eff:.2f}")
                logger.info(f"Architecture: {member_arch}")

            new_population = []
            # First `mutation_numbers` children come from mutation, the rest from crossover.
            for child_idx in range(mutation_numbers + crossover_numbers):
                if child_idx < mutation_numbers:
                    par_sample = parents[np.random.randint(parents_size)][1]
                    new_sample, efficiency, acc = self.mutate_sample(par_sample, query_counter)
                else:
                    par_sample1 = parents[np.random.randint(parents_size)][1]
                    par_sample2 = parents[np.random.randint(parents_size)][1]
                    new_sample, efficiency, acc = self.crossover_sample(par_sample1, par_sample2, query_counter)
                new_population.append([acc, new_sample, efficiency])

                if acc > best_info[0]:
                    best_info = [acc, new_sample, efficiency, query_counter[0]]
                    new_best_in_iter = True
                    new_best_query = query_counter[0]
                    logger.info(f"Iteration {iteration}: New global best architecture found at query {query_counter[0]} with accuracy {acc:.4f}, efficiency {efficiency:.2f}")

            population = sorted(parents + new_population, key=lambda x: x[0])[::-1][:population_size]

            iter_query_info = f"New best found at query {new_best_query}" if new_best_in_iter else "No new best found"
            logger.info(f"Iteration {iteration} completed: {query_counter[0] - iter_start_query} new architectures tried")
            logger.info(f"Global best after iteration {iteration}: accuracy {best_info[0]:.4f}, FLOPs {best_info[2]:.2f} at query {best_info[3]} - {iter_query_info}")

            global_best_arch_records.append((iteration, best_info[0], best_info[2], best_info[3]))

        logger.info("\nEvolution search completed!")
        logger.info(f"Total queries: {query_counter[0]}")
        logger.info("Best architecture found:")
        logger.info(f"Accuracy: {best_info[0]:.4f}")
        logger.info(f"Efficiency: {best_info[2]:.2f}")
        logger.info(f"Architecture: {best_info[1]}")
        logger.info(f"Found at query: {best_info[3]}")

        logger.info("\nGlobal best architectures after each iteration:")
        for iter_num, acc, flops, query_num in global_best_arch_records:
            # The initial population gets its own label so it is not confused with iteration 0.
            label = "initial population" if iter_num is None else f"iteration {iter_num}"
            logger.info(f"After {label}: Best ACC {acc:.4f}, FLOPs {flops:.2f} at query {query_num}")

        return best_info
